"""SSH key management."""

import typing

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from . import secrets
from . import token

# Smallest RSA key size, in bits, that SshPrivateKey._create_token will generate.
MIN_KEY_SIZE: int = 4_096


@token.register_token('ssh_private_key')
class SshPrivateKey(token.Token):
    """SSH private key.

    Each token version stores an RSA key pair as two secret fields:
    ``private_key`` (OpenSSH private-key format, PEM-armoured, unencrypted)
    and ``public_key`` (a single OpenSSH public-key line, optionally followed
    by a comment).
    """

    # Setting consumed by token.Token; presumably the number of active
    # versions kept when old versions are cleaned up -- confirm in token.Token.
    clean_active_versions = 1

    def _create_token(self, token_version: str, meta: dict[str, typing.Any]) -> dict[str, str]:
        """Create an SSH key.

        Generates a fresh RSA key pair with public exponent 65537.

        Args:
            token_version: Version being created (not used for key generation).
            meta: Token metadata. ``meta['key_size']`` (required) is the RSA key
                size in bits. ``meta['comment']`` (optional) is appended to the
                public-key line when it is non-empty.

        Returns:
            A dict with ``'private_key'`` (unencrypted OpenSSH private key, PEM
            encoded) and ``'public_key'`` (OpenSSH public-key line).

        Raises:
            KeyError: If ``meta`` has no ``'key_size'`` entry.
            ValueError: If ``meta['key_size']`` is below ``MIN_KEY_SIZE``.
        """
        if (key_size := meta['key_size']) < MIN_KEY_SIZE:
            raise ValueError(f'Key size too small: {key_size} < {MIN_KEY_SIZE}')
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
        private_pem = private_key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.OpenSSH,
            serialization.NoEncryption(),
        ).decode('ascii').strip()
        public_line = private_key.public_key().public_bytes(
            serialization.Encoding.OpenSSH,
            serialization.PublicFormat.OpenSSH,
        ).decode('ascii').strip()
        # Append the comment only when one is given. An empty or missing
        # comment previously produced a trailing space or raised KeyError.
        if comment := meta.get('comment'):
            public_line = f'{public_line} {comment}'
        return {
            'private_key': private_pem,
            'public_key': public_line,
        }

    def _destroy_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Destroy a token version."""
        # nothing to do to destroy an SSH key

    def _update_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Update a token version's metadata from its stored key.

        Loads the stored private key and records its size in
        ``meta['key_size']``. Keys whose type has no ``key_size`` attribute
        (e.g. Ed25519) leave ``meta`` unchanged instead of raising.
        """
        private_key = serialization.load_ssh_private_key(
            secrets.secret(f'{self.full_token_name(token_version)}:private_key').encode('ascii'),
            None,
        )
        # getattr needs an explicit default here: without one it raises
        # AttributeError for key types that do not expose key_size.
        if key_size := getattr(private_key, 'key_size', None):
            meta['key_size'] = key_size

    def _validate_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Validate a token version.

        Checks that the stored public key belongs to the stored private key.
        Both keys are re-serialized in OpenSSH public format before comparing,
        so a comment on the stored public-key line does not affect the result.

        Raises:
            ValueError: If the public key does not match the private key. The
                cryptography loaders may also raise if either key cannot be
                parsed.
        """
        token_secret = secrets.secret(f'{self.full_token_name(token_version)}:')
        private_key = serialization.load_ssh_private_key(
            token_secret['private_key'].encode('ascii'), None)
        public_key = serialization.load_ssh_public_key(
            token_secret['public_key'].encode('ascii'), None)
        openssh_format = (serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH)
        if (private_key.public_key().public_bytes(*openssh_format)
                != public_key.public_bytes(*openssh_format)):
            raise ValueError('Public key does not match private key')
